import networkx as nx
import numpy as np
import scipy
import torch
from scipy.sparse import *

def zero_prior(G):
    '''
        Builds a prior that assigns 0. to every edge of G.

        G: object whose edges() method yields hashable edge keys.

        Returns a dict mapping each edge yielded by G.edges() to 0.
    '''
    return {edge: 0. for edge in G.edges()}

def lsq_matrix_flow(G, train_flows):
    '''
        Generates matrix A and vector b to solve flow problem
        as least-squares, also returns index for recovering
        edges from solution.

        G: networkx graph; edges are visited in G.edges() order,
            which is also the order used for the entries of the
            known-flow vector.
        train_flows: dict mapping edge tuples (as yielded by
            G.edges()) to their known flow values.

        Returns (A, b, index):
            A: sparse COO matrix B * sigma, where B is the oriented
                incidence matrix of G and sigma selects the columns
                of the edges that are not in train_flows.
            b: -B * f0, where f0 holds the known flows (0 for the
                edges without a known flow).
            index: dict mapping each edge without a known flow to
                its column in A (its row in the solution vector).

        Assumes the edge tuples yielded by G.edges() are unique
        (index is keyed by edge tuple, so parallel edges would
        collapse onto a single key).
    '''
    B = nx.incidence_matrix(G, oriented=True)
    n_edges = G.number_of_edges()
    f0 = np.zeros(n_edges)

    row = []
    col = []
    index = {}

    # Single pass over the edges: known flows go into f0, every other
    # edge gets the next free column of the selection matrix.
    n_free = 0
    for i, e in enumerate(G.edges()):
        if e in train_flows:
            f0[i] = train_flows[e]
        else:
            row.append(i)
            col.append(n_free)
            index[e] = n_free
            n_free += 1

    b = -B.dot(f0)

    # The column count is the number of edges actually left without a
    # flow, so keys of train_flows that are not edges of G cannot make
    # the declared shape too small for the column indices.
    data = [1.] * n_free
    sigma = coo_matrix((data, (row, col)), shape=(n_edges, n_free))

    A = B.dot(sigma).tocoo()

    return A, b, index

def get_dict_flows_from_tensor(index, x, test_flows):
    '''
        Collects the entries of x for the requested edges in a dict.

        index: dict mapping an edge to its row in x.
        x: 2-D tensor/array; column 0 is read and converted to a
            Python scalar with .item().
        test_flows: iterable of edges to look up (each must be a
            key of index).

        Returns a dict mapping each edge in test_flows to its value.
    '''
    return {edge: x[index[edge], 0].item() for edge in test_flows}

def initialize_flows(n_edges, zeros=False):
    '''
        Creates the initial flow values as a trainable column tensor.

        n_edges: number of rows of the returned tensor.
        zeros: if truthy, start from all zeros; otherwise draw
            uniform random values and rescale them so that the
            largest entry is 1.

        Returns a torch.float tensor of shape (n_edges, 1) with
        requires_grad=True, placed on 'cuda:0' when CUDA is
        available and on torch's default device otherwise.
    '''
    # Plain truthiness so that flags such as numpy booleans or 1 are
    # honoured too (an identity test against True would ignore them).
    if zeros:
        flows = np.zeros((n_edges, 1))
    else:
        flows = np.random.random((n_edges, 1))
        # np.max raises on an empty array and a zero maximum would
        # produce NaNs, so rescale only when there is a positive peak.
        if flows.size > 0:
            peak = np.max(flows)
            if peak > 0:
                flows = flows / peak

    device = 'cuda:0' if torch.cuda.is_available() else None
    return torch.tensor(flows, requires_grad=True, dtype=torch.float, device=device)
